import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from sklearn import preprocessing
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score
import sys,os
sys.path.append(r"/home/yh579/GAFM/GAFM/models")
from GAFM import train_GAFM
from Marvell import train_marvell
from Vanilla import train_vanilla
from MaxNorm import train_maxnorm
from bases import process_data,DataSet
from tqdm import trange
# Load the tabular dataset. The original inline note mentioned
# 'default of credit card clients.csv' -- presumably an alternative input
# file for the same pipeline (confirm before swapping).
raw_df = pd.read_csv('/home/yh579/GAFM/GAFM/data/spam.csv')

# Standardize every column except the last one, which is left untouched
# (presumably the label column -- confirm against process_data).
scaler = preprocessing.StandardScaler()
feature_columns = raw_df.columns[:-1]
# Assign by column label so the scaled float values replace the columns
# outright, instead of being written into the existing (possibly
# integer-typed) columns through .iloc, which relies on in-place dtype
# coercion.
raw_df[feature_columns] = scaler.fit_transform(raw_df[feature_columns])
# NOTE(review): the scaler is fit on the full frame before process_data runs;
# if process_data performs the train/test split, test-set statistics leak
# into the scaling -- confirm and move the fit after the split if so.

train_loader, test_loader, features = process_data(raw_df)

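#IMDB: alternative dataset, currently disabled. The commented-out block below builds features, train_loader and test_loader from the keras IMDB data instead of spam.csv.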
# config = {
#     "batch_size":1028
# }
# from keras.datasets import imdb
# num_words=500
# (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=num_words)
# def vectorize_sequences(sequences, dimension=num_words):
#     results = np.zeros((len(sequences), dimension))
#     for i, sequence in enumerate(sequences):
#         results[i, sequence] = 1.
#     return results
#
# train_features = vectorize_sequences(train_data)
# test_features = vectorize_sequences(test_data)
# train_labels = np.asarray(train_labels).astype('float32')
# test_labels = np.asarray(test_labels).astype('float32')
# print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
#     train_labels.shape[0],train_labels.sum(), 100 * train_labels.sum() / train_labels.shape[0]))
# features=train_features
# train_dataset = DataSet(train_features,
#                         train_labels.astype(np.float64).reshape(-1, 1))
# train_loader = torch.utils.data.DataLoader(train_dataset,
#                                            batch_size=config["batch_size"],
#                                            shuffle=True)
#
# test_dataset = DataSet(test_features,
#                        test_labels.astype(np.float64).reshape(-1, 1))
# test_loader = torch.utils.data.DataLoader(test_dataset,
#                                           batch_size=config["batch_size"],
#                                           shuffle=True)


## Experiment parameters shared by every method below.
# Running minimum of the NA leak AUC; a run's model is kept when its
# NA leak AUC falls below this value.
best = 1
# Passed as `Epochs=` to each train_* call.
Epochs = 300
# How many times each method is trained (the loop seeds `random` with 0..repeats-1).
repeats = 5
# Passed as `lr=` to each train_* call.
lr = 0.0001
# Passed as `sigma=` to train_GAFM only.
sigma = 0.01


#4.2.1 GAFM random_fix
# Train GAFM (mode='random_fix') `repeats` times, collect per-run metrics,
# keep the model with the lowest NA leak AUC, then print raw values and
# mean/std for every metric.
delta = 0.05
print('Training GAFM random_fix...', 'Delta=', delta)
train_auc_list_GAFM01_random_fix = []
test_auc_list_GAFM01_random_fix = []
na_leak_auc_list_GAFM01_random_fix = []
ma_leak_auc_list_GAFM01_random_fix = []
cos_leak_auc_list_GAFM01_random_fix = []
train_tvd_list_GAFM01_random_fix = []

# Track the lowest NA leak AUC for this method only. Starting from +inf
# (rather than 1) lets a run whose leak AUC is exactly 1.0 still be selected.
best = float('inf')
for i in trange(repeats):
    # Seed every RNG the training code may draw from, not just `random`,
    # so that repeat i is tied to seed i.
    random.seed(i)
    np.random.seed(i)
    torch.manual_seed(i)
    (
        train_auc_GAFM_random_fix,
        test_auc_GAFM_random_fix,
        train_tvd_GAFM_random_fix,
        na_leak_auc_GAFM_random_fix,
        ma_leak_auc_GAFM_random_fix,
        cos_leak_auc_GAFM_random_fix,
        splitnn_GAFM_random_fix,
        _,
        _,
        _,
    ) = train_GAFM(
        Epochs=Epochs,
        delta=delta,
        features=features,
        train_loader=train_loader,
        test_loader=test_loader,
        sigma=sigma,
        regenerate=False,
        mode='random_fix',
        lr=lr,
        info=True,
        standardization=True,
    )
    # Discard runs whose training AUC is exactly 0.5 (or NaN): the folded
    # value max(auc, 1 - auc) must be strictly above chance level.
    if max(train_auc_GAFM_random_fix, 1 - train_auc_GAFM_random_fix) > 0.5:
        # Train/test AUCs are recorded folded around 0.5; leak AUCs and TVD
        # are recorded as returned.
        train_auc_list_GAFM01_random_fix.append(
            max(train_auc_GAFM_random_fix, 1 - train_auc_GAFM_random_fix))
        test_auc_list_GAFM01_random_fix.append(
            max(test_auc_GAFM_random_fix, 1 - test_auc_GAFM_random_fix))
        train_tvd_list_GAFM01_random_fix.append(train_tvd_GAFM_random_fix)
        na_leak_auc_list_GAFM01_random_fix.append(na_leak_auc_GAFM_random_fix)
        ma_leak_auc_list_GAFM01_random_fix.append(ma_leak_auc_GAFM_random_fix)
        cos_leak_auc_list_GAFM01_random_fix.append(cos_leak_auc_GAFM_random_fix)
        if na_leak_auc_GAFM_random_fix < best:
            best = na_leak_auc_GAFM_random_fix
            GAFM_model = splitnn_GAFM_random_fix

# NOTE(review): the cos_leak_* values are printed under the label
# 'Median Leak AUC' here but 'Cos Leak AUC' in the Vanilla section --
# confirm which name is intended.
gafm_metrics = (
    ('Training AUC', train_auc_list_GAFM01_random_fix),
    ('Testing AUC', test_auc_list_GAFM01_random_fix),
    ('TVD', train_tvd_list_GAFM01_random_fix),
    ('NA Leak AUC', na_leak_auc_list_GAFM01_random_fix),
    ('MA Leak AUC', ma_leak_auc_list_GAFM01_random_fix),
    ('Median Leak AUC', cos_leak_auc_list_GAFM01_random_fix),
)
# Raw per-run values first, then mean and standard deviation per metric.
for metric_name, metric_values in gafm_metrics:
    print('GAFM_random_fix ' + metric_name, metric_values)

for metric_name, metric_values in gafm_metrics:
    print('GAFM_random_fix Mean ' + metric_name,
          np.mean(metric_values), np.std(metric_values))


#1.Vanilla
# Train the vanilla baseline `repeats` times, collect per-run metrics, keep
# the model with the lowest NA leak AUC, then print raw values and mean/std.
print('Training Vanilla...')
train_auc_list_vanilla = []
test_auc_list_vanilla = []
train_tvd_list_vanilla = []
na_leak_auc_list_vanilla = []
ma_leak_auc_list_vanilla = []
cos_leak_auc_list_vanilla = []

# Reset the running minimum so the selection is per method: without this,
# `best` still holds the lowest leak AUC of whichever method ran before and
# vanilla_model may never be assigned. +inf (rather than 1) also lets a run
# with leak AUC exactly 1.0 be selected.
best = float('inf')
for i in trange(repeats):
    # Seed every RNG the training code may draw from, not just `random`,
    # so that repeat i is tied to seed i.
    random.seed(i)
    np.random.seed(i)
    torch.manual_seed(i)
    (
        train_auc,
        test_auc,
        train_tvd,
        na_leak_auc,
        ma_leak_auc,
        cos_leak_auc,
        splitnn,
    ) = train_vanilla(
        Epochs=Epochs,
        features=features,
        train_loader=train_loader,
        test_loader=test_loader,
        lr=lr,
        info=True,
    )
    train_auc_list_vanilla.append(train_auc)
    test_auc_list_vanilla.append(test_auc)
    train_tvd_list_vanilla.append(train_tvd)
    na_leak_auc_list_vanilla.append(na_leak_auc)
    ma_leak_auc_list_vanilla.append(ma_leak_auc)
    cos_leak_auc_list_vanilla.append(cos_leak_auc)
    if na_leak_auc < best:
        best = na_leak_auc
        vanilla_model = splitnn

vanilla_metrics = (
    ('Training AUC', train_auc_list_vanilla),
    ('Testing AUC', test_auc_list_vanilla),
    ('TVD', train_tvd_list_vanilla),
    ('NA Leak AUC', na_leak_auc_list_vanilla),
    ('MA Leak AUC', ma_leak_auc_list_vanilla),
    ('Cos Leak AUC', cos_leak_auc_list_vanilla),
)
# Raw per-run values first, then mean and standard deviation per metric.
for metric_name, metric_values in vanilla_metrics:
    print('Vanilla ' + metric_name, metric_values)

for metric_name, metric_values in vanilla_metrics:
    print('Vanilla Mean ' + metric_name,
          np.mean(metric_values), np.std(metric_values))

#2. MaxNorm
# Train the MaxNorm baseline `repeats` times, collect per-run metrics, keep
# the model with the lowest NA leak AUC, then print raw values and mean/std.
print('Training MaxNorm...')
train_auc_list_maxnorm = []
test_auc_list_maxnorm = []
train_tvd_list_maxnorm = []
na_leak_auc_list_maxnorm = []
ma_leak_auc_list_maxnorm = []
cos_leak_auc_list_maxnorm = []

# Reset the running minimum so the selection is per method: without this,
# `best` still holds the lowest leak AUC of whichever method ran before and
# maxnorm_model may never be assigned. +inf (rather than 1) also lets a run
# with leak AUC exactly 1.0 be selected.
best = float('inf')
for i in trange(repeats):
    # Seed every RNG the training code may draw from, not just `random`,
    # so that repeat i is tied to seed i.
    random.seed(i)
    np.random.seed(i)
    torch.manual_seed(i)
    (
        train_auc_maxnorm,
        test_auc_maxnorm,
        train_tvd_maxnorm,
        na_leak_auc_maxnorm,
        ma_leak_auc_maxnorm,
        cos_leak_auc_maxnorm,
        splitnn_maxnorm,
    ) = train_maxnorm(
        Epochs=Epochs,
        features=features,
        train_loader=train_loader,
        test_loader=test_loader,
        lr=lr,
        info=True,
    )
    train_auc_list_maxnorm.append(train_auc_maxnorm)
    test_auc_list_maxnorm.append(test_auc_maxnorm)
    train_tvd_list_maxnorm.append(train_tvd_maxnorm)
    na_leak_auc_list_maxnorm.append(na_leak_auc_maxnorm)
    ma_leak_auc_list_maxnorm.append(ma_leak_auc_maxnorm)
    cos_leak_auc_list_maxnorm.append(cos_leak_auc_maxnorm)
    if na_leak_auc_maxnorm < best:
        best = na_leak_auc_maxnorm
        maxnorm_model = splitnn_maxnorm

# NOTE(review): the cos_leak_* values are printed under the label
# 'Median Leak AUC' here but 'Cos Leak AUC' in the Vanilla section --
# confirm which name is intended.
maxnorm_metrics = (
    ('Training AUC', train_auc_list_maxnorm),
    ('Testing AUC', test_auc_list_maxnorm),
    ('TVD', train_tvd_list_maxnorm),
    ('NA Leak AUC', na_leak_auc_list_maxnorm),
    ('MA Leak AUC', ma_leak_auc_list_maxnorm),
    ('Median Leak AUC', cos_leak_auc_list_maxnorm),
)
# Raw per-run values first, then mean and standard deviation per metric.
for metric_name, metric_values in maxnorm_metrics:
    print('MaxNorm ' + metric_name, metric_values)

for metric_name, metric_values in maxnorm_metrics:
    print('MaxNorm Mean ' + metric_name,
          np.mean(metric_values), np.std(metric_values))


#3. Marvell
# Train the Marvell baseline `repeats` times (retrying a run that raises
# RuntimeError), collect per-run metrics, keep the model with the lowest
# NA leak AUC, then print raw values and mean/std.
# NOTE(review): the banner below says 'GAFM Marvell' although this section
# only calls train_marvell -- left unchanged to keep log output stable.
print('Training GAFM Marvell..')
train_auc_list_marvell = []
test_auc_list_marvell = []
train_tvd_list_marvell = []
na_leak_auc_list_marvell = []
ma_leak_auc_list_marvell = []
cos_leak_auc_list_marvell = []

# Total attempts per repeat: one initial call plus three retries, matching
# the depth of the original nested try/except chain.
marvell_max_attempts = 4

# Reset the running minimum so the selection is per method: without this,
# `best` still holds the lowest leak AUC of whichever method ran before and
# marvell_model may never be assigned. +inf (rather than 1) also lets a run
# with leak AUC exactly 1.0 be selected.
best = float('inf')
for i in trange(repeats):
    # Seed every RNG the training code may draw from, not just `random`,
    # so that repeat i is tied to seed i. Seeding happens once per repeat,
    # so a retry continues the RNG streams instead of replaying the exact
    # attempt that just failed.
    random.seed(i)
    np.random.seed(i)
    torch.manual_seed(i)
    # train_marvell may raise RuntimeError (presumably a transient numerical
    # failure -- confirm); retry, and let the error from the final attempt
    # propagate.
    for marvell_attempt in range(1, marvell_max_attempts + 1):
        try:
            (
                train_auc_marvell,
                test_auc_marvell,
                train_tvd_marvell,
                na_leak_auc_marvell,
                ma_leak_auc_marvell,
                cos_leak_auc_marvell,
                splitnn_marvell,
            ) = train_marvell(
                Epochs=Epochs,
                features=features,
                train_loader=train_loader,
                test_loader=test_loader,
                lr=lr,
                info=True,
            )
            break
        except RuntimeError:
            if marvell_attempt == marvell_max_attempts:
                raise

    train_auc_list_marvell.append(train_auc_marvell)
    test_auc_list_marvell.append(test_auc_marvell)
    train_tvd_list_marvell.append(train_tvd_marvell)
    na_leak_auc_list_marvell.append(na_leak_auc_marvell)
    ma_leak_auc_list_marvell.append(ma_leak_auc_marvell)
    cos_leak_auc_list_marvell.append(cos_leak_auc_marvell)
    if na_leak_auc_marvell < best:
        best = na_leak_auc_marvell
        marvell_model = splitnn_marvell

# NOTE(review): the cos_leak_* values are printed under the label
# 'Median Leak AUC' here but 'Cos Leak AUC' in the Vanilla section --
# confirm which name is intended.
marvell_metrics = (
    ('Training AUC', train_auc_list_marvell),
    ('Testing AUC', test_auc_list_marvell),
    ('TVD', train_tvd_list_marvell),
    ('NA Leak AUC', na_leak_auc_list_marvell),
    ('MA Leak AUC', ma_leak_auc_list_marvell),
    ('Median Leak AUC', cos_leak_auc_list_marvell),
)
# Raw per-run values first, then mean and standard deviation per metric.
for metric_name, metric_values in marvell_metrics:
    print('Marvell ' + metric_name, metric_values)

# The 'MarvellMean' prefix (no space) is reproduced as in the original output.
for metric_name, metric_values in marvell_metrics:
    print('MarvellMean ' + metric_name,
          np.mean(metric_values), np.std(metric_values))

